import torch
import collections

import collections.abc

# Re-expose these ABCs under the top-level `collections` namespace (the old
# aliases were removed in Python 3.10); presumably a dependency imported below
# still references e.g. `collections.Iterable`. `collections.abc` is imported
# explicitly instead of relying on another module having loaded it already.
for _abc_name in ("Iterable", "Mapping", "MutableSet", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name

import tltorch
import math

import tensorly as tly

# Select tensorly's PyTorch backend so tensorly operations act on torch tensors
# (presumably required by the tltorch Tucker ops used below).
tly.set_backend('pytorch')


class Conv2d_tucker_vanilla(torch.nn.Conv2d):
    """2D convolution whose kernel is stored as a Tucker decomposition.

    The kernel of shape ``(out_channels, in_channels, kH, kW)`` is represented by a
    core tensor ``C`` (shape ``self.rank``) and one factor matrix per mode in ``Us``
    (shape ``(dims[i], rank[i])``). The factors start with orthonormal columns and
    can be updated by plain gradient descent through :meth:`step`.

    The dense ``weight`` created by ``torch.nn.Conv2d`` is not used by ``forward``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, groups=1, bias=True,
                 dilation=1, start_rank_percent=0.4) -> None:
        """
        Build the layer and initialize its Tucker core and factors.

        INPUTS:
        in_planes : number of input channels (Pytorch's ``in_channels``)
        out_planes : number of output channels (Pytorch's ``out_channels``)
        kernel_size : kernel size of the convolutional filter (Pytorch's standard)
        stride : stride of the filter (Pytorch's standard)
        padding : padding of the convolution (Pytorch's standard)
        groups : forwarded to ``torch.nn.Conv2d``. NOTE(review): ``forward`` does not pass it
                 to ``tucker_conv`` and the factors are sized from the full ``in_channels``,
                 so grouped convolution is presumably not honored -- confirm before using
                 groups > 1.
        bias : flag variable for the bias to be included (Pytorch's standard)
        dilation : dilation of the convolution (Pytorch's standard)
        start_rank_percent : fraction of the output/input channel counts used as the Tucker
                 rank of those two modes; the result is raised to at least 3 and capped at
                 the channel count. The two kernel modes keep full rank.
        """
        super().__init__(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
                         bias=bias, dilation=dilation)

        # Mode sizes of the full kernel tensor: (out, in, kH, kW).
        self.dims = [self.out_channels, self.in_channels] + list(self.kernel_size)

        # Channel ranks: a fraction of the channel count, at least 3 (e.g. for rgb inputs),
        # but never more than the channel count itself. A (d, r) factor with r > d cannot
        # have orthonormal columns, and its reduced QR would shrink it to (d, d), which
        # no longer matches the core.
        self.rank = [min(max(int(d * start_rank_percent), 3), d) for d in self.dims[:2]] + self.dims[2:]

        self.C = torch.nn.Parameter(torch.empty(size=self.rank), requires_grad=True)
        self.Us = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(size=(d, r)), requires_grad=True) for d, r in zip(self.dims, self.rank)])

        self.reset_tucker_parameters()  # parameter initialization

    @torch.no_grad()
    def reset_tucker_parameters(self):
        """Randomly (re)initialize the core and give every factor orthonormal columns.

        The factors are overwritten in place, so the registered ``Parameter`` objects
        (and any optimizer already holding references to them) are preserved.
        """
        torch.nn.init.kaiming_uniform_(self.C, a=math.sqrt(5))
        for U in self.Us:
            torch.nn.init.kaiming_uniform_(U, a=math.sqrt(5))
            # Orthonormalize the basis; rank <= dim guarantees Q has the same (d, r) shape.
            Q, _ = torch.linalg.qr(U, mode='reduced')
            U.copy_(Q)

    def forward(self, input):
        """
        Convolve ``input`` with the kernel given by the Tucker decomposition (C, Us).

        The core and the factors are truncated to the current ``self.rank`` and passed to
        ``tltorch.functional.tucker_conv`` together with the layer's bias, stride, padding
        and dilation.
        """
        core = self.C[tuple(slice(0, r) for r in self.rank)]
        factors = [U[:, :r] for U, r in zip(self.Us, self.rank)]

        return tltorch.functional.tucker_conv(input,
                                              tucker_tensor=tltorch.TuckerTensor(core, factors, rank=self.rank),
                                              bias=self.bias, stride=self.stride, padding=self.padding,
                                              dilation=self.dilation)

    @torch.no_grad()
    def step(self, lr=0.05):
        """
        Apply one plain gradient-descent update ``p <- p - lr * p.grad`` in place to every
        factor and to the core. Parameters whose gradient is ``None`` are left unchanged.
        """
        for param in (*self.Us, self.C):
            if param.grad is not None:
                param.sub_(lr * param.grad)

    @torch.no_grad()
    def get_r_mod_i(self, i):
        """
        Return ``min(rank[i], prod(rank[j] for j != i))``: the smaller side of the mode-``i``
        unfolding of the core, i.e. an upper bound on that unfolding's matrix rank.
        """
        return min(self.rank[i], math.prod([r for j, r in enumerate(self.rank) if j != i]))

    @torch.no_grad()
    def set_grad_zero(self):
        """Zero the gradients of the factors and the core in place; missing gradients are skipped."""
        for param in (*self.Us, self.C):
            if param.grad is not None:
                param.grad.zero_()
